{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.1.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# ResNet-101 (on CIFAR-10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Network Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The network in this notebook is an implementation of the ResNet-101 [1] architecture, trained as a 10-class image classifier on the CIFAR-10 dataset.\n",
    "\n",
    "\n",
    "References\n",
    "    \n",
    "- [1] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778). ([CVPR Link](https://www.cv-foundation.org/openaccess/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html))\n",
    "\n",
    "- [2] Zhang, K., Tan, L., Li, Z., & Qiao, Y. (2016). Gender and smile classification using deep convolutional neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops (pp. 34-38).\n",
    "\n",
    "The ResNet-101 architecture is similar to the ResNet-50 architecture, which in turn is similar to the ResNet-34 architecture shown below (screenshot from [1]). The differences are that ResNet-101 uses bottleneck blocks instead of the basic residual blocks of ResNet-34, and that it stacks more bottleneck blocks than ResNet-50 (23 instead of 6 in the third stage):\n",
    "\n",
    "\n",
    "![](../images/resnets/resnet101/resnet101-arch-1.png)\n"
   ]
  },
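  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number in ResNet-101 counts the weight layers on the main path: the initial 7x7 convolution, the convolutions inside the residual blocks, and the final fully connected layer. The 1x1 shortcut projections are not counted. The sketch below assumes this counting convention. `resnet_depth` is a hypothetical helper. The block counts `[3, 4, 23, 3]` are the ones passed to `ResNet` in the model cell further down, and the ResNet-50/ResNet-34 counts `[3, 4, 6, 3]` are taken from [1].\n",
    "\n",
    "```python\n",
    "def resnet_depth(blocks_per_stage, convs_per_block):\n",
    "    # stem conv (1) + convolutions in all residual blocks + final fc layer (1)\n",
    "    return 1 + convs_per_block * sum(blocks_per_stage) + 1\n",
    "\n",
    "\n",
    "# Bottleneck blocks contain three convolutions (1x1, 3x3, 1x1)\n",
    "print(resnet_depth([3, 4, 23, 3], convs_per_block=3))  # ResNet-101 -> 101\n",
    "print(resnet_depth([3, 4, 6, 3], convs_per_block=3))   # ResNet-50  -> 50\n",
    "# Basic blocks contain two 3x3 convolutions\n",
    "print(resnet_depth([3, 4, 6, 3], convs_per_block=2))   # ResNet-34  -> 34\n",
    "```"
   ]
  },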
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following figure illustrates residual blocks with skip connections such that the input passed via the shortcut matches the dimensions of the main path's output, which allows the network to learn identity functions.\n",
    "\n",
    "![](../images/resnets/resnet-ex-1-1.png)\n",
    "\n",
    "\n",
    "Where the dimensions change between stages, the ResNet-34 architecture uses residual blocks with modified skip connections, in which the input passed via the shortcut is resized to the dimensions of the main path's output. Such a residual block is illustrated below:\n",
    "\n",
    "![](../images/resnets/resnet-ex-1-2.png)\n",
    "\n",
    "ResNet-50, ResNet-101, and ResNet-152 use bottleneck blocks instead, as shown below:\n",
    "\n",
    "![](../images/resnets/resnet-ex-1-3.png)"
   ]
  },
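  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A little arithmetic shows why the bottleneck design helps. The sketch below counts convolution weights for two blocks: a basic block on 64 channels, and a bottleneck block whose input and output have 256 channels with a 64-channel inner width. The bottleneck matches the `Bottleneck` class in the model cell (`expansion = 4`). Biases are not counted because these convolutions use `bias=False`, and batch norm parameters are ignored. The helper names and channel sizes are illustrative assumptions. The bottleneck block handles 4x wider inputs and outputs with slightly fewer weights than the 64-channel basic block. A basic block at 256 channels would need roughly 17x more weights.\n",
    "\n",
    "```python\n",
    "def conv_weights(in_channels, out_channels, kernel_size):\n",
    "    # Weight count of a bias-free 2D convolution\n",
    "    return in_channels * out_channels * kernel_size * kernel_size\n",
    "\n",
    "\n",
    "def basic_block_weights(channels):\n",
    "    # Two 3x3 convolutions (ResNet-18/34 style)\n",
    "    return 2 * conv_weights(channels, channels, 3)\n",
    "\n",
    "\n",
    "def bottleneck_weights(planes, expansion=4):\n",
    "    # 1x1 reduce -> 3x3 -> 1x1 expand (ResNet-50/101/152 style);\n",
    "    # assumes input and output both have planes * expansion channels\n",
    "    outer = planes * expansion\n",
    "    return (conv_weights(outer, planes, 1)\n",
    "            + conv_weights(planes, planes, 3)\n",
    "            + conv_weights(planes, outer, 1))\n",
    "\n",
    "\n",
    "print(basic_block_weights(64))   # 73728\n",
    "print(bottleneck_weights(64))    # 69632 (256 channels in and out)\n",
    "print(basic_block_weights(256))  # 1179648\n",
    "```"
   ]
  },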
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a more detailed explanation see the other notebook, [resnet-ex-1.ipynb](resnet-ex-1.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.dataset import Subset\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "RANDOM_SEED = 1\n",
    "LEARNING_RATE = 0.01\n",
    "NUM_EPOCHS = 50\n",
    "\n",
    "# Architecture\n",
    "NUM_CLASSES = 10\n",
    "BATCH_SIZE = 128\n",
    "DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "GRAYSCALE = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Image batch dimensions: torch.Size([128, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([128])\n",
      "Image batch dimensions: torch.Size([128, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([128])\n",
      "Image batch dimensions: torch.Size([128, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### CIFAR-10 Dataset\n",
    "##########################\n",
    "\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "\n",
    "\n",
    "train_indices = torch.arange(0, 49000)\n",
    "valid_indices = torch.arange(49000, 50000)\n",
    "\n",
    "\n",
    "train_and_valid = datasets.CIFAR10(root='data', \n",
    "                                   train=True, \n",
    "                                   transform=transforms.ToTensor(),\n",
    "                                   download=True)\n",
    "\n",
    "train_dataset = Subset(train_and_valid, train_indices)\n",
    "valid_dataset = Subset(train_and_valid, valid_indices)\n",
    "\n",
    "\n",
    "test_dataset = datasets.CIFAR10(root='data', \n",
    "                                train=False, \n",
    "                                transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "#####################################################\n",
    "### Data Loaders\n",
    "#####################################################\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=8,\n",
    "                          shuffle=True)\n",
    "\n",
    "valid_loader = DataLoader(dataset=valid_dataset, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=8,\n",
    "                          shuffle=False)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         num_workers=8,\n",
    "                         shuffle=False)\n",
    "\n",
    "#####################################################\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break\n",
    "\n",
    "for images, labels in test_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break\n",
    "    \n",
    "for images, labels in valid_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
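  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split above holds out the last 1,000 of the 50,000 CIFAR-10 training images for validation, which leaves 49,000 images for training. `DataLoader` uses `drop_last=False` by default, so with `BATCH_SIZE = 128` the final training batch is smaller than the others. This matches the 383 batches per epoch reported in the training log further below. The short sketch below reproduces the arithmetic. The variable names are illustrative and are not used elsewhere in the notebook.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "num_train = 49000  # len(train_indices)\n",
    "num_valid = 1000   # len(valid_indices)\n",
    "batch_size = 128   # BATCH_SIZE\n",
    "\n",
    "# DataLoader keeps the final, partial batch unless drop_last=True\n",
    "train_batches = math.ceil(num_train / batch_size)\n",
    "last_batch_size = num_train - (train_batches - 1) * batch_size\n",
    "valid_batches = math.ceil(num_valid / batch_size)\n",
    "\n",
    "print(train_batches)    # 383\n",
    "print(last_batch_size)  # 104\n",
    "print(valid_batches)    # 8\n",
    "```"
   ]
  },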
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code cell, which implements the ResNet-101 architecture, is a derivative of the code provided at https://pytorch.org/docs/0.4.0/_modules/torchvision/models/resnet.html."
   ]
  },
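  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The initialization loop in `ResNet.__init__` below draws each convolution's weights from a normal distribution with mean 0 and standard deviation sqrt(2 / n), where n = kernel height x kernel width x output channels. This is a He (Kaiming) style initialization based on the fan-out. It appears to be equivalent to `torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')`. The sketch below uses a hypothetical helper, `he_std`, to evaluate that standard deviation for a few layer shapes from this network.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def he_std(kernel_h, kernel_w, out_channels):\n",
    "    # Mirrors the init loop in ResNet.__init__: std = sqrt(2 / n)\n",
    "    n = kernel_h * kernel_w * out_channels\n",
    "    return math.sqrt(2.0 / n)\n",
    "\n",
    "\n",
    "print(he_std(7, 7, 64))   # stem conv1\n",
    "print(he_std(1, 1, 256))  # 1x1 expansion conv in a layer1 Bottleneck\n",
    "print(he_std(3, 3, 512))  # 3x3 conv in a layer4 Bottleneck\n",
    "```"
   ]
  },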
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "def conv3x3(in_planes, out_planes, stride=1):\n",
    "    \"\"\"3x3 convolution with padding\"\"\"\n",
    "    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
    "                     padding=1, bias=False)\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
    "        super(Bottleneck, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n",
    "                               padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(planes * 4)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv3(out)\n",
    "        out = self.bn3(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            residual = self.downsample(x)\n",
    "\n",
    "        out += residual\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "\n",
    "    def __init__(self, block, layers, num_classes, grayscale):\n",
    "        self.inplanes = 64\n",
    "        if grayscale:\n",
    "            in_dim = 1\n",
    "        else:\n",
    "            in_dim = 3\n",
    "        super(ResNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,\n",
    "                               bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        self.layer1 = self._make_layer(block, 64, layers[0])\n",
    "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
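    "        self.avgpool = nn.AvgPool2d(7, stride=1, padding=2)  # unused for 32x32 inputs\n",
    "        # For 32x32 CIFAR-10 inputs, layer4 outputs a 1x1 feature map, so the\n",
    "        # flattened features have 512 * block.expansion = 2048 entries\n",
    "        self.fc = nn.Linear(512 * block.expansion, num_classes)\n",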
    "\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
    "                m.weight.data.normal_(0, (2. / n)**.5)\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                m.weight.data.fill_(1)\n",
    "                m.bias.data.zero_()\n",
    "\n",
    "    def _make_layer(self, block, planes, blocks, stride=1):\n",
    "        downsample = None\n",
    "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
    "            downsample = nn.Sequential(\n",
    "                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
    "                          kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(planes * block.expansion),\n",
    "            )\n",
    "\n",
    "        layers = []\n",
    "        layers.append(block(self.inplanes, planes, stride, downsample))\n",
    "        self.inplanes = planes * block.expansion\n",
    "        for i in range(1, blocks):\n",
    "            layers.append(block(self.inplanes, planes))\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.layer4(x)\n",
    "\n",
    "        #x = self.avgpool(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        logits = self.fc(x)\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas\n",
    "\n",
    "\n",
    "\n",
    "def resnet101(num_classes, grayscale):\n",
    "    \"\"\"Constructs a ResNet-101 model.\"\"\"\n",
    "    model = ResNet(block=Bottleneck, \n",
    "                   layers=[3, 4, 23, 3],\n",
    "                   num_classes=num_classes,\n",
    "                   grayscale=grayscale)\n",
    "    return model"
   ]
  },
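  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CIFAR-10 images are only 32x32 pixels. The ImageNet-style stem (a 7x7 convolution with stride 2, then 3x3 max pooling with stride 2) and the three stride-2 stages shrink the feature map to 1x1 by the end of `layer4`. For that reason, `forward` skips average pooling and flattens directly into `self.fc`, which expects 512 * 4 = 2048 features. The sketch below traces the spatial size with the standard output-size formula floor((n + 2p - k) / s) + 1. It also checks that, in the first block of each downsampling stage, two paths yield the same size: the 3x3 stride-2 convolution on the main path and the 1x1 stride-2 `downsample` convolution on the shortcut. Matching sizes are what allow the two to be added. This is plain arithmetic with a hypothetical helper `conv_out`, not a call into PyTorch.\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel_size, stride=1, padding=0):\n",
    "    # Output height/width of a convolution or pooling layer\n",
    "    return (size + 2 * padding - kernel_size) // stride + 1\n",
    "\n",
    "\n",
    "size = 32                                      # CIFAR-10 input height/width\n",
    "size = conv_out(size, 7, stride=2, padding=3)  # conv1\n",
    "print('conv1:', size)                          # 16\n",
    "size = conv_out(size, 3, stride=2, padding=1)  # maxpool\n",
    "print('maxpool:', size)                        # 8\n",
    "print('layer1:', size)                         # stride 1, size unchanged\n",
    "\n",
    "for name in ['layer2', 'layer3', 'layer4']:\n",
    "    main = conv_out(size, 3, stride=2, padding=1)      # 3x3 conv2 in the first Bottleneck\n",
    "    shortcut = conv_out(size, 1, stride=2, padding=0)  # 1x1 downsample conv\n",
    "    assert main == shortcut  # required for out += residual\n",
    "    size = main\n",
    "    print(name + ':', size)                    # 4, 2, 1\n",
    "\n",
    "print('fc input features:', 512 * 4 * size * size)  # 2048\n",
    "```"
   ]
  },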
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "##########################\n",
    "### COST AND OPTIMIZER\n",
    "##########################\n",
    "\n",
    "model = resnet101(NUM_CLASSES, GRAYSCALE)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/050 | Batch 000/383 | Cost: 2.5163\n",
      "Epoch: 001/050 | Batch 120/383 | Cost: 2.6981\n",
      "Epoch: 001/050 | Batch 240/383 | Cost: 2.4884\n",
      "Epoch: 001/050 | Batch 360/383 | Cost: 2.2649\n",
      "Epoch: 001/050 Train Acc.: 22.23% | Validation Acc.: 20.20%\n",
      "Time elapsed: 0.91 min\n",
      "Epoch: 002/050 | Batch 000/383 | Cost: 2.0539\n",
      "Epoch: 002/050 | Batch 120/383 | Cost: 1.9494\n",
      "Epoch: 002/050 | Batch 240/383 | Cost: 1.7358\n",
      "Epoch: 002/050 | Batch 360/383 | Cost: 1.6881\n",
      "Epoch: 002/050 Train Acc.: 34.62% | Validation Acc.: 33.30%\n",
      "Time elapsed: 1.82 min\n",
      "Epoch: 003/050 | Batch 000/383 | Cost: 1.5947\n",
      "Epoch: 003/050 | Batch 120/383 | Cost: 1.6122\n",
      "Epoch: 003/050 | Batch 240/383 | Cost: 1.6285\n",
      "Epoch: 003/050 | Batch 360/383 | Cost: 1.5403\n",
      "Epoch: 003/050 Train Acc.: 44.03% | Validation Acc.: 44.20%\n",
      "Time elapsed: 2.75 min\n",
      "Epoch: 004/050 | Batch 000/383 | Cost: 1.4653\n",
      "Epoch: 004/050 | Batch 120/383 | Cost: 1.3565\n",
      "Epoch: 004/050 | Batch 240/383 | Cost: 1.4571\n",
      "Epoch: 004/050 | Batch 360/383 | Cost: 1.2938\n",
      "Epoch: 004/050 Train Acc.: 52.26% | Validation Acc.: 51.60%\n",
      "Time elapsed: 3.68 min\n",
      "Epoch: 005/050 | Batch 000/383 | Cost: 1.3441\n",
      "Epoch: 005/050 | Batch 120/383 | Cost: 1.2296\n",
      "Epoch: 005/050 | Batch 240/383 | Cost: 1.1155\n",
      "Epoch: 005/050 | Batch 360/383 | Cost: 1.2562\n",
      "Epoch: 005/050 Train Acc.: 57.62% | Validation Acc.: 56.30%\n",
      "Time elapsed: 4.62 min\n",
      "Epoch: 006/050 | Batch 000/383 | Cost: 1.0643\n",
      "Epoch: 006/050 | Batch 120/383 | Cost: 1.1455\n",
      "Epoch: 006/050 | Batch 240/383 | Cost: 1.2560\n",
      "Epoch: 006/050 | Batch 360/383 | Cost: 1.1945\n",
      "Epoch: 006/050 Train Acc.: 61.26% | Validation Acc.: 57.10%\n",
      "Time elapsed: 5.55 min\n",
      "Epoch: 007/050 | Batch 000/383 | Cost: 1.0354\n",
      "Epoch: 007/050 | Batch 120/383 | Cost: 1.0614\n",
      "Epoch: 007/050 | Batch 240/383 | Cost: 0.8850\n",
      "Epoch: 007/050 | Batch 360/383 | Cost: 1.2730\n",
      "Epoch: 007/050 Train Acc.: 65.91% | Validation Acc.: 61.50%\n",
      "Time elapsed: 6.48 min\n",
      "Epoch: 008/050 | Batch 000/383 | Cost: 0.8989\n",
      "Epoch: 008/050 | Batch 120/383 | Cost: 0.9142\n",
      "Epoch: 008/050 | Batch 240/383 | Cost: 1.1074\n",
      "Epoch: 008/050 | Batch 360/383 | Cost: 0.8856\n",
      "Epoch: 008/050 Train Acc.: 70.48% | Validation Acc.: 64.60%\n",
      "Time elapsed: 7.40 min\n",
      "Epoch: 009/050 | Batch 000/383 | Cost: 0.9335\n",
      "Epoch: 009/050 | Batch 120/383 | Cost: 0.9332\n",
      "Epoch: 009/050 | Batch 240/383 | Cost: 0.8908\n",
      "Epoch: 009/050 | Batch 360/383 | Cost: 0.6551\n",
      "Epoch: 009/050 Train Acc.: 74.30% | Validation Acc.: 68.60%\n",
      "Time elapsed: 8.33 min\n",
      "Epoch: 010/050 | Batch 000/383 | Cost: 0.7369\n",
      "Epoch: 010/050 | Batch 120/383 | Cost: 1.1510\n",
      "Epoch: 010/050 | Batch 240/383 | Cost: 1.0534\n",
      "Epoch: 010/050 | Batch 360/383 | Cost: 1.6784\n",
      "Epoch: 010/050 Train Acc.: 43.73% | Validation Acc.: 46.20%\n",
      "Time elapsed: 9.25 min\n",
      "Epoch: 011/050 | Batch 000/383 | Cost: 1.6847\n",
      "Epoch: 011/050 | Batch 120/383 | Cost: 1.2999\n",
      "Epoch: 011/050 | Batch 240/383 | Cost: 1.3228\n",
      "Epoch: 011/050 | Batch 360/383 | Cost: 1.2020\n",
      "Epoch: 011/050 Train Acc.: 64.80% | Validation Acc.: 63.40%\n",
      "Time elapsed: 10.19 min\n",
      "Epoch: 012/050 | Batch 000/383 | Cost: 0.8810\n",
      "Epoch: 012/050 | Batch 120/383 | Cost: 1.0206\n",
      "Epoch: 012/050 | Batch 240/383 | Cost: 0.9342\n",
      "Epoch: 012/050 | Batch 360/383 | Cost: 0.9090\n",
      "Epoch: 012/050 Train Acc.: 72.51% | Validation Acc.: 69.50%\n",
      "Time elapsed: 11.12 min\n",
      "Epoch: 013/050 | Batch 000/383 | Cost: 0.7491\n",
      "Epoch: 013/050 | Batch 120/383 | Cost: 0.7770\n",
      "Epoch: 013/050 | Batch 240/383 | Cost: 0.6177\n",
      "Epoch: 013/050 | Batch 360/383 | Cost: 0.8193\n",
      "Epoch: 013/050 Train Acc.: 77.19% | Validation Acc.: 71.10%\n",
      "Time elapsed: 12.05 min\n",
      "Epoch: 014/050 | Batch 000/383 | Cost: 0.7384\n",
      "Epoch: 014/050 | Batch 120/383 | Cost: 0.7387\n",
      "Epoch: 014/050 | Batch 240/383 | Cost: 0.8657\n",
      "Epoch: 014/050 | Batch 360/383 | Cost: 0.8790\n",
      "Epoch: 014/050 Train Acc.: 76.11% | Validation Acc.: 70.50%\n",
      "Time elapsed: 12.98 min\n",
      "Epoch: 015/050 | Batch 000/383 | Cost: 0.8404\n",
      "Epoch: 015/050 | Batch 120/383 | Cost: 0.5830\n",
      "Epoch: 015/050 | Batch 240/383 | Cost: 0.5412\n",
      "Epoch: 015/050 | Batch 360/383 | Cost: 0.5490\n",
      "Epoch: 015/050 Train Acc.: 81.81% | Validation Acc.: 73.10%\n",
      "Time elapsed: 13.91 min\n",
      "Epoch: 016/050 | Batch 000/383 | Cost: 0.4776\n",
      "Epoch: 016/050 | Batch 120/383 | Cost: 0.6313\n",
      "Epoch: 016/050 | Batch 240/383 | Cost: 0.6662\n",
      "Epoch: 016/050 | Batch 360/383 | Cost: 1.2079\n",
      "Epoch: 016/050 Train Acc.: 70.04% | Validation Acc.: 65.80%\n",
      "Time elapsed: 14.85 min\n",
      "Epoch: 017/050 | Batch 000/383 | Cost: 0.8772\n",
      "Epoch: 017/050 | Batch 120/383 | Cost: 0.5628\n",
      "Epoch: 017/050 | Batch 240/383 | Cost: 0.6132\n",
      "Epoch: 017/050 | Batch 360/383 | Cost: 0.5744\n",
      "Epoch: 017/050 Train Acc.: 85.69% | Validation Acc.: 76.20%\n",
      "Time elapsed: 15.78 min\n",
      "Epoch: 018/050 | Batch 000/383 | Cost: 0.2691\n",
      "Epoch: 018/050 | Batch 120/383 | Cost: 0.4488\n",
      "Epoch: 018/050 | Batch 240/383 | Cost: 0.4294\n",
      "Epoch: 018/050 | Batch 360/383 | Cost: 0.4160\n",
      "Epoch: 018/050 Train Acc.: 85.21% | Validation Acc.: 74.30%\n",
      "Time elapsed: 16.70 min\n",
      "Epoch: 019/050 | Batch 000/383 | Cost: 0.3669\n",
      "Epoch: 019/050 | Batch 120/383 | Cost: 0.4060\n",
      "Epoch: 019/050 | Batch 240/383 | Cost: 0.3356\n",
      "Epoch: 019/050 | Batch 360/383 | Cost: 0.4026\n",
      "Epoch: 019/050 Train Acc.: 88.68% | Validation Acc.: 75.60%\n",
      "Time elapsed: 17.63 min\n",
      "Epoch: 020/050 | Batch 000/383 | Cost: 0.3351\n",
      "Epoch: 020/050 | Batch 120/383 | Cost: 0.2692\n",
      "Epoch: 020/050 | Batch 240/383 | Cost: 0.4012\n",
      "Epoch: 020/050 | Batch 360/383 | Cost: 0.4054\n",
      "Epoch: 020/050 Train Acc.: 78.68% | Validation Acc.: 69.30%\n",
      "Time elapsed: 18.56 min\n",
      "Epoch: 021/050 | Batch 000/383 | Cost: 0.6588\n",
      "Epoch: 021/050 | Batch 120/383 | Cost: 0.2520\n",
      "Epoch: 021/050 | Batch 240/383 | Cost: 0.3063\n",
      "Epoch: 021/050 | Batch 360/383 | Cost: 0.3941\n",
      "Epoch: 021/050 Train Acc.: 91.77% | Validation Acc.: 75.10%\n",
      "Time elapsed: 19.46 min\n",
      "Epoch: 022/050 | Batch 000/383 | Cost: 0.1590\n",
      "Epoch: 022/050 | Batch 120/383 | Cost: 0.2021\n",
      "Epoch: 022/050 | Batch 240/383 | Cost: 0.2791\n",
      "Epoch: 022/050 | Batch 360/383 | Cost: 0.3070\n",
      "Epoch: 022/050 Train Acc.: 94.08% | Validation Acc.: 75.00%\n",
      "Time elapsed: 20.41 min\n",
      "Epoch: 023/050 | Batch 000/383 | Cost: 0.2246\n",
      "Epoch: 023/050 | Batch 120/383 | Cost: 0.1973\n",
      "Epoch: 023/050 | Batch 240/383 | Cost: 0.3057\n",
      "Epoch: 023/050 | Batch 360/383 | Cost: 0.3288\n",
      "Epoch: 023/050 Train Acc.: 94.23% | Validation Acc.: 75.90%\n",
      "Time elapsed: 21.34 min\n",
      "Epoch: 024/050 | Batch 000/383 | Cost: 0.1660\n",
      "Epoch: 024/050 | Batch 120/383 | Cost: 0.2438\n",
      "Epoch: 024/050 | Batch 240/383 | Cost: 0.1335\n",
      "Epoch: 024/050 | Batch 360/383 | Cost: 0.2232\n",
      "Epoch: 024/050 Train Acc.: 95.17% | Validation Acc.: 77.60%\n",
      "Time elapsed: 22.29 min\n",
      "Epoch: 025/050 | Batch 000/383 | Cost: 0.0962\n",
      "Epoch: 025/050 | Batch 120/383 | Cost: 0.2140\n",
      "Epoch: 025/050 | Batch 240/383 | Cost: 0.3161\n",
      "Epoch: 025/050 | Batch 360/383 | Cost: 0.2601\n",
      "Epoch: 025/050 Train Acc.: 95.61% | Validation Acc.: 75.50%\n",
      "Time elapsed: 23.22 min\n",
      "Epoch: 026/050 | Batch 000/383 | Cost: 0.1091\n",
      "Epoch: 026/050 | Batch 120/383 | Cost: 0.1121\n",
      "Epoch: 026/050 | Batch 240/383 | Cost: 0.2550\n",
      "Epoch: 026/050 | Batch 360/383 | Cost: 0.1442\n",
      "Epoch: 026/050 Train Acc.: 95.77% | Validation Acc.: 75.80%\n",
      "Time elapsed: 24.14 min\n",
      "Epoch: 027/050 | Batch 000/383 | Cost: 0.1545\n",
      "Epoch: 027/050 | Batch 120/383 | Cost: 0.1142\n",
      "Epoch: 027/050 | Batch 240/383 | Cost: 0.1245\n",
      "Epoch: 027/050 | Batch 360/383 | Cost: 0.1778\n",
      "Epoch: 027/050 Train Acc.: 96.51% | Validation Acc.: 76.60%\n",
      "Time elapsed: 25.07 min\n",
      "Epoch: 028/050 | Batch 000/383 | Cost: 0.0774\n",
      "Epoch: 028/050 | Batch 120/383 | Cost: 0.0916\n",
      "Epoch: 028/050 | Batch 240/383 | Cost: 0.2275\n",
      "Epoch: 028/050 | Batch 360/383 | Cost: 0.0742\n",
      "Epoch: 028/050 Train Acc.: 95.90% | Validation Acc.: 77.10%\n",
      "Time elapsed: 26.00 min\n",
      "Epoch: 029/050 | Batch 000/383 | Cost: 0.0556\n",
      "Epoch: 029/050 | Batch 120/383 | Cost: 0.0649\n",
      "Epoch: 029/050 | Batch 240/383 | Cost: 0.1699\n",
      "Epoch: 029/050 | Batch 360/383 | Cost: 0.0963\n",
      "Epoch: 029/050 Train Acc.: 95.77% | Validation Acc.: 74.80%\n",
      "Time elapsed: 26.93 min\n",
      "Epoch: 030/050 | Batch 000/383 | Cost: 0.2278\n",
      "Epoch: 030/050 | Batch 120/383 | Cost: 0.1565\n",
      "Epoch: 030/050 | Batch 240/383 | Cost: 0.0929\n",
      "Epoch: 030/050 | Batch 360/383 | Cost: 1.0334\n",
      "Epoch: 030/050 Train Acc.: 53.82% | Validation Acc.: 52.10%\n",
      "Time elapsed: 27.85 min\n",
      "Epoch: 031/050 | Batch 000/383 | Cost: 1.5570\n",
      "Epoch: 031/050 | Batch 120/383 | Cost: 0.6029\n",
      "Epoch: 031/050 | Batch 240/383 | Cost: 0.4034\n",
      "Epoch: 031/050 | Batch 360/383 | Cost: 0.3380\n",
      "Epoch: 031/050 Train Acc.: 94.16% | Validation Acc.: 73.70%\n",
      "Time elapsed: 28.78 min\n",
      "Epoch: 032/050 | Batch 000/383 | Cost: 0.1454\n",
      "Epoch: 032/050 | Batch 120/383 | Cost: 0.1692\n",
      "Epoch: 032/050 | Batch 240/383 | Cost: 0.0922\n",
      "Epoch: 032/050 | Batch 360/383 | Cost: 0.1078\n",
      "Epoch: 032/050 Train Acc.: 97.44% | Validation Acc.: 77.20%\n",
      "Time elapsed: 29.71 min\n",
      "Epoch: 033/050 | Batch 000/383 | Cost: 0.1039\n",
      "Epoch: 033/050 | Batch 120/383 | Cost: 0.0764\n",
      "Epoch: 033/050 | Batch 240/383 | Cost: 0.1007\n",
      "Epoch: 033/050 | Batch 360/383 | Cost: 0.0518\n",
      "Epoch: 033/050 Train Acc.: 97.78% | Validation Acc.: 76.20%\n",
      "Time elapsed: 30.64 min\n",
      "Epoch: 034/050 | Batch 000/383 | Cost: 0.0672\n",
      "Epoch: 034/050 | Batch 120/383 | Cost: 0.0719\n",
      "Epoch: 034/050 | Batch 240/383 | Cost: 0.1163\n",
      "Epoch: 034/050 | Batch 360/383 | Cost: 0.1522\n",
      "Epoch: 034/050 Train Acc.: 97.79% | Validation Acc.: 75.80%\n",
      "Time elapsed: 31.58 min\n",
      "Epoch: 035/050 | Batch 000/383 | Cost: 0.1177\n",
      "Epoch: 035/050 | Batch 120/383 | Cost: 0.0802\n",
      "Epoch: 035/050 | Batch 240/383 | Cost: 0.1278\n",
      "Epoch: 035/050 | Batch 360/383 | Cost: 0.0857\n",
      "Epoch: 035/050 Train Acc.: 97.90% | Validation Acc.: 76.40%\n",
      "Time elapsed: 32.52 min\n",
      "Epoch: 036/050 | Batch 000/383 | Cost: 0.0574\n",
      "Epoch: 036/050 | Batch 120/383 | Cost: 0.0644\n",
      "Epoch: 036/050 | Batch 240/383 | Cost: 0.1070\n",
      "Epoch: 036/050 | Batch 360/383 | Cost: 0.0326\n",
      "Epoch: 036/050 Train Acc.: 97.71% | Validation Acc.: 75.60%\n",
      "Time elapsed: 33.43 min\n",
      "Epoch: 037/050 | Batch 000/383 | Cost: 0.0406\n",
      "Epoch: 037/050 | Batch 120/383 | Cost: 0.0697\n",
      "Epoch: 037/050 | Batch 240/383 | Cost: 0.0651\n",
      "Epoch: 037/050 | Batch 360/383 | Cost: 0.0908\n",
      "Epoch: 037/050 Train Acc.: 97.94% | Validation Acc.: 77.20%\n",
      "Time elapsed: 34.37 min\n",
      "Epoch: 038/050 | Batch 000/383 | Cost: 0.0772\n",
      "Epoch: 038/050 | Batch 120/383 | Cost: 0.0609\n",
      "Epoch: 038/050 | Batch 240/383 | Cost: 0.1069\n",
      "Epoch: 038/050 | Batch 360/383 | Cost: 0.0757\n",
      "Epoch: 038/050 Train Acc.: 98.20% | Validation Acc.: 76.80%\n",
      "Time elapsed: 35.31 min\n",
      "Epoch: 039/050 | Batch 000/383 | Cost: 0.0176\n",
      "Epoch: 039/050 | Batch 120/383 | Cost: 0.0788\n",
      "Epoch: 039/050 | Batch 240/383 | Cost: 0.1234\n",
      "Epoch: 039/050 | Batch 360/383 | Cost: 0.0626\n",
      "Epoch: 039/050 Train Acc.: 97.88% | Validation Acc.: 76.70%\n",
      "Time elapsed: 36.24 min\n",
      "Epoch: 040/050 | Batch 000/383 | Cost: 0.1171\n",
      "Epoch: 040/050 | Batch 120/383 | Cost: 0.0533\n",
      "Epoch: 040/050 | Batch 240/383 | Cost: 0.1050\n",
      "Epoch: 040/050 | Batch 360/383 | Cost: 0.0686\n",
      "Epoch: 040/050 Train Acc.: 98.21% | Validation Acc.: 75.10%\n",
      "Time elapsed: 37.17 min\n",
      "Epoch: 041/050 | Batch 000/383 | Cost: 0.0568\n",
      "Epoch: 041/050 | Batch 120/383 | Cost: 0.0160\n",
      "Epoch: 041/050 | Batch 240/383 | Cost: 0.0414\n",
      "Epoch: 041/050 | Batch 360/383 | Cost: 0.1025\n",
      "Epoch: 041/050 Train Acc.: 98.27% | Validation Acc.: 77.00%\n",
      "Time elapsed: 38.09 min\n",
      "Epoch: 042/050 | Batch 000/383 | Cost: 0.0302\n",
      "Epoch: 042/050 | Batch 120/383 | Cost: 0.0280\n",
      "Epoch: 042/050 | Batch 240/383 | Cost: 0.0703\n",
      "Epoch: 042/050 | Batch 360/383 | Cost: 0.0316\n",
      "Epoch: 042/050 Train Acc.: 98.09% | Validation Acc.: 75.70%\n",
      "Time elapsed: 39.02 min\n",
      "Epoch: 043/050 | Batch 000/383 | Cost: 0.0589\n",
      "Epoch: 043/050 | Batch 120/383 | Cost: 0.0294\n",
      "Epoch: 043/050 | Batch 240/383 | Cost: 0.0760\n",
      "Epoch: 043/050 | Batch 360/383 | Cost: 0.0859\n",
      "Epoch: 043/050 Train Acc.: 98.11% | Validation Acc.: 75.90%\n",
      "Time elapsed: 39.96 min\n",
      "Epoch: 044/050 | Batch 000/383 | Cost: 0.0466\n",
      "Epoch: 044/050 | Batch 120/383 | Cost: 0.0802\n",
      "Epoch: 044/050 | Batch 240/383 | Cost: 0.0781\n",
      "Epoch: 044/050 | Batch 360/383 | Cost: 0.0717\n",
      "Epoch: 044/050 Train Acc.: 98.52% | Validation Acc.: 77.00%\n",
      "Time elapsed: 40.89 min\n",
      "Epoch: 045/050 | Batch 000/383 | Cost: 0.0566\n",
      "Epoch: 045/050 | Batch 120/383 | Cost: 0.0935\n",
      "Epoch: 045/050 | Batch 240/383 | Cost: 0.0267\n",
      "Epoch: 045/050 | Batch 360/383 | Cost: 0.0600\n",
      "Epoch: 045/050 Train Acc.: 98.18% | Validation Acc.: 77.90%\n",
      "Time elapsed: 41.83 min\n",
      "Epoch: 046/050 | Batch 000/383 | Cost: 0.0506\n",
      "Epoch: 046/050 | Batch 120/383 | Cost: 0.0650\n",
      "Epoch: 046/050 | Batch 240/383 | Cost: 0.0094\n",
      "Epoch: 046/050 | Batch 360/383 | Cost: 0.0197\n",
      "Epoch: 046/050 Train Acc.: 98.54% | Validation Acc.: 77.50%\n",
      "Time elapsed: 42.76 min\n",
      "Epoch: 047/050 | Batch 000/383 | Cost: 0.0694\n",
      "Epoch: 047/050 | Batch 120/383 | Cost: 0.0149\n",
      "Epoch: 047/050 | Batch 240/383 | Cost: 0.0602\n",
      "Epoch: 047/050 | Batch 360/383 | Cost: 0.0592\n",
      "Epoch: 047/050 Train Acc.: 98.01% | Validation Acc.: 77.80%\n",
      "Time elapsed: 43.69 min\n",
      "Epoch: 048/050 | Batch 000/383 | Cost: 0.0822\n",
      "Epoch: 048/050 | Batch 120/383 | Cost: 0.1484\n",
      "Epoch: 048/050 | Batch 240/383 | Cost: 0.0758\n",
      "Epoch: 048/050 | Batch 360/383 | Cost: 0.0467\n",
      "Epoch: 048/050 Train Acc.: 98.75% | Validation Acc.: 76.40%\n",
      "Time elapsed: 44.62 min\n",
      "Epoch: 049/050 | Batch 000/383 | Cost: 0.0427\n",
      "Epoch: 049/050 | Batch 120/383 | Cost: 0.0119\n",
      "Epoch: 049/050 | Batch 240/383 | Cost: 0.0414\n",
      "Epoch: 049/050 | Batch 360/383 | Cost: 0.0574\n",
      "Epoch: 049/050 Train Acc.: 98.89% | Validation Acc.: 77.20%\n",
      "Time elapsed: 45.57 min\n",
      "Epoch: 050/050 | Batch 000/383 | Cost: 0.0492\n",
      "Epoch: 050/050 | Batch 120/383 | Cost: 0.0210\n",
      "Epoch: 050/050 | Batch 240/383 | Cost: 0.0981\n",
      "Epoch: 050/050 | Batch 360/383 | Cost: 0.0344\n",
      "Epoch: 050/050 Train Acc.: 98.27% | Validation Acc.: 76.50%\n",
      "Time elapsed: 46.50 min\n",
      "Total Training Time: 46.50 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader, device):\n",
    "    # switch to evaluation mode so that BatchNorm layers use (and do not\n",
    "    # update) their running statistics; the training loop calls\n",
    "    # model.train() again at the start of each epoch\n",
    "    model.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for i, (features, targets) in enumerate(data_loader):\n",
    "\n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "# use random seed for reproducibility (here batch shuffling)\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "    \n",
    "        ### PREPARE MINIBATCH\n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 120:\n",
    "            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                   f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n",
    "                   f' Cost: {cost:.4f}')\n",
    "\n",
    "    # no need to build the computation graph for backprop when computing accuracy\n",
    "    with torch.set_grad_enabled(False):\n",
    "        train_acc = compute_accuracy(model, train_loader, device=DEVICE)\n",
    "        valid_acc = compute_accuracy(model, valid_loader, device=DEVICE)\n",
    "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'\n",
    "              f' | Validation Acc.: {valid_acc:.2f}%')\n",
    "        \n",
    "    elapsed = (time.time() - start_time)/60\n",
    "    print(f'Time elapsed: {elapsed:.2f} min')\n",
    "  \n",
    "elapsed = (time.time() - start_time)/60\n",
    "print(f'Total Training Time: {elapsed:.2f} min')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "paaeEQHQj5xC"
   },
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6514,
     "status": "ok",
     "timestamp": 1524976895054,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "gzQMWKq5j5xE",
    "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 75.15%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False): # save memory during inference\n",
    "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "pandas      0.23.4\n",
      "torch       1.1.0\n",
      "PIL.Image   5.3.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
